import torch
from torch import nn


class Rearange(nn.Module):
    """Hold the geometry for cutting a square image into square patches.

    Attributes:
        h, w: height and width of one patch (both equal ``patch_size``).
        nh, nw: number of patches along each side,
            ``image_size // patch_size``.

    NOTE(review): the class name looks like a misspelling of "Rearrange";
    it is kept unchanged so existing callers keep working.
    """

    def __init__(self, image_size=14, patch_size=7):
        """Record patch size and per-side patch counts.

        Args:
            image_size: side length of the input image (treated as square).
            patch_size: side length of each square patch.
        """
        # nn.Module.__init__ must run before anything else: it creates the
        # internal registries (_parameters, _modules, _buffers, hooks) and the
        # `training` flag that __call__, .parameters(), .to(), .eval() and
        # repr() depend on. Without it those calls raise AttributeError.
        super().__init__()
        self.h = patch_size
        self.w = patch_size
        # Patches per side. Floor division means any remainder pixels are not
        # counted when image_size is not a multiple of patch_size.
        self.nw = image_size // patch_size
        self.nh = image_size // patch_size

        # Total number of patches (nh * nw).
        # NOTE(review): not stored on self; confirm where it is consumed.
        num_patches = (image_size // patch_size) ** 2